In [50]:
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from functools import partial
from IPython.display import HTML, IFrame
from jaxtyping import Float
from os import environ
from transformer_lens import ActivationCache, HookedTransformer
from typing import List, Optional, Union
import circuitsvis as cv
import dotenv
import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import plotly.graph_objs as go
import plotly
import torch
import tqdm.auto as tqdm
import transformer_lens
import transformer_lens.utils as utils
# from utils import imshow, line, scatter

plotly.offline.init_notebook_mode()

dotenv.load_dotenv('.env', override=True)
print(f'{environ.get("PYTORCH_ENABLE_MPS_FALLBACK")=}')

torch.set_grad_enabled(False)

DEVICE = utils.get_device()
print(f'{DEVICE=}')

PRETRAINED_MODEL = 'gpt2-small'  # gpt2-small or gpt2-medium
print(f'{PRETRAINED_MODEL=}')

model = transformer_lens.HookedTransformer.from_pretrained(
    PRETRAINED_MODEL, center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True, device=DEVICE)

print("test model generate:", model.generate(
    "The Space Needle is in the city of"))

print("print model structure", model)


def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()
environ.get("PYTORCH_ENABLE_MPS_FALLBACK")='1'
DEVICE=device(type='mps')
PRETRAINED_MODEL='gpt2-small'
/Users/lihenan/miniconda3/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning:

`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.

Loaded pretrained model gpt2-small into HookedTransformer
100%|██████████| 10/10 [00:00<00:00, 18.29it/s]
test model generate: The Space Needle is in the city of Los Angeles now to part ways with its crowd manager
print model structure HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_mid): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (ln_final): LayerNormPre(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (unembed): Unembed()
)

In [ ]:
# what are all number tokens?
In [ ]:
prompts = [
    "123+456=",
    " 123+456=",
    " 123 +456=",
    " 123 + 456=",
    " 123 + 456 =",
    " 123 + 456 = ",
    "123 +456=",
    "123 + 456=",
    "123 + 456 =",
    "123+ 456=",
    "123+ 456 =",
    "123+ 456 = ",
    "123+456 =",
    "123+456 = ",
    "123+456= ",
]
answers = [  # each answer contains 2 tokens
    " 579",
]


# utils.test_prompt(prompt='123+123=', answer='579', model=model,
#                   prepend_bos=True, prepend_space_to_answer=False, print_details=False)

for p in prompts:
    for a in answers:
        print(f"prompt=|{p}|, answer=|{a}|")
        utils.test_prompt(prompt=p, answer=a, model=model, top_k=1)
prompt=|123 + 456 =|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', ' +', ' 4', '56', ' =']
Tokenized answer: [' 5', '79']
Performance on answer token:
Rank: 6        Logit: 11.39 Prob:  2.09% Token: | 5|
Top 0th token. Logit: 11.95 Prob:  3.67% Token: | 1|
Performance on answer token:
Rank: 27       Logit:  9.31 Prob:  0.51% Token: |79|
Top 0th token. Logit: 12.48 Prob: 12.00% Token: |.|
Ranks of the answer tokens: [(' 5', 6), ('79', 27)]


prompt=|123+ 456 =|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', '+', ' 4', '56', ' =']
Tokenized answer: [' 5', '79']
Performance on answer token:
Rank: 5        Logit:  9.91 Prob:  1.63% Token: | 5|
Top 0th token. Logit: 10.61 Prob:  3.29% Token: | 1|
Performance on answer token:
Rank: 80       Logit:  8.25 Prob:  0.26% Token: |79|
Top 0th token. Logit: 11.77 Prob:  8.72% Token: |.|
Ranks of the answer tokens: [(' 5', 5), ('79', 80)]
  • first output token logit diff from answer token logit = |11.39 - 11.95| = -0.56
  • second output token logit diff from answer token logit = |9.31 - 12.48| = -3.17
In [157]:
correct_wrong_answer_tokens = model.to_tokens(
    " 5 1", prepend_bos=False).to(DEVICE)
print(f'{correct_wrong_answer_tokens=}')
answer_residual_directions = model.tokens_to_residual_directions(
    correct_wrong_answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
# logit_diff_directions = (
#     # token | 5| - token | 1|
#     answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
# )
# print(f'{logit_diff_directions.shape=}')
print(f'{answer_residual_directions[0,0,0]=}')
# logit_diff_directions = (
#     answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
# )
# print("Logit difference directions shape:", logit_diff_directions.shape)

# 768 refers to the dimensionality of the embedding and hidden states within the model

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer.
# The general syntax is [activation_name, layer_index, sub_layer_type].
input_tokens0 = model.to_tokens("123 + 456 =", prepend_bos=True)
output_logit0, cache0 = model.run_with_cache(input_tokens0)
print(f'logit for token | 5|: {output_logit0[0, -1, 642]=}')
print(f'logit for token | 1|: {output_logit0[0, -1, 352]=}')
assert (torch.eq(output_logit0[0, -1, 642].round(decimals=4), 11.3875).item())

input_tokens1 = model.to_tokens(
    "123 + 456 = 5", prepend_bos=True)  # advance 1 more token
output_logit1, cache1 = model.run_with_cache(input_tokens1)
print(f'logit for token |79|: {output_logit1[0, -1, 3720]=}')
assert (torch.eq(output_logit1[0, -1, 3720].round(decimals=4), 9.3148).item())

final_residual_stream0 = cache0["resid_post", -1]  # shape [1, 6, 768]
final_token_residual_stream = final_residual_stream0[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache0.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)
print(f'{scaled_final_token_residual_stream.shape=}')
print(f'{scaled_final_token_residual_stream[0,0]=}')
print(f'{cache0["ln_final.hook_normalized"].shape=}')
print(f'{cache0["ln_final.hook_normalized"][0,-1,0]=}')

calculated_logit_diff = cache0["ln_final.hook_normalized"][:, -1, :]
calculated_logit_diff = calculated_logit_diff @ model.unembed.W_U + model.unembed.b_U
print(f'{calculated_logit_diff.shape=}')
print(f'{calculated_logit_diff[0,642]=}')


def logits_to_ave_logit_diff(logits, answer_tokens):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    print(f'{answer_logits=}')
    return answer_logits[:, 0] - answer_logits[:, 1]


print(
    "Per prompt logit difference:",
    logits_to_ave_logit_diff(output_logit0, correct_wrong_answer_tokens)
    .detach()
    .cpu()
    .round(decimals=3),
)

# ============== logit lens

accumulated_residual, labels = cache0.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True, apply_ln=True
)
print(f'{labels=}')
accumulated_residual = accumulated_residual @ model.unembed.W_U + model.unembed.b_U
print(f'{accumulated_residual.shape=}')
# line(
#     torch.stack((accumulated_residual[:, 0, 642],
#                 accumulated_residual[:, 0, 352])),
#     x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
#     hover_name=labels,
#     title="Logit Difference From Accumulate Residual Stream",
# )

fig = go.Figure()
fig.add_trace(go.Scatter(
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    y=utils.to_numpy(accumulated_residual[:, 0, 642]),
))
fig.add_trace(go.Scatter(
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    y=utils.to_numpy(accumulated_residual[:, 0, 352]),
))
fig.add_trace(go.Scatter(
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    y=utils.to_numpy(accumulated_residual[:, 0, 3720]),
))
fig.show()

per_layer_residual, labels = cache0.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True, apply_ln=True
)
print(f'{per_layer_residual.shape=} {labels=}')
per_layer_residual = per_layer_residual @ model.unembed.W_U + model.unembed.b_U
fig = go.Figure()
fig.add_trace(go.Scatter(
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    y=utils.to_numpy(per_layer_residual[:, 0, 642]),
    text=labels,
))
fig.add_trace(go.Scatter(
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    y=utils.to_numpy(per_layer_residual[:, 0, 352]),
))
fig.add_trace(go.Scatter(
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    y=utils.to_numpy(per_layer_residual[:, 0, 3720]),
))
fig.add_trace(go.Scatter(
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    y=utils.to_numpy(per_layer_residual[:, 0, 41734]),
))
fig.show()

stack_head_result, labels = cache0.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True, apply_ln=True
)
stack_head_result = stack_head_result @ model.unembed.W_U + model.unembed.b_U
stack_head_result = stack_head_result.view(12, 12, -1)
px.imshow(
    utils.to_numpy(stack_head_result[:,:, 642] - stack_head_result[:,:, 352]),
    labels={"x": "Head", "y": "Layer"},
).show()
px.imshow(
    utils.to_numpy(stack_head_result[:,:, 352]),
    labels={"x": "Head", "y": "Layer"},
).show()
correct_wrong_answer_tokens=tensor([[642, 352]], device='mps:0')
Answer residual directions shape: torch.Size([1, 2, 768])
answer_residual_directions[0,0,0]=tensor(-0.1862, device='mps:0')
logit for token | 5|: output_logit0[0, -1, 642]=tensor(11.3875, device='mps:0')
logit for token | 1|: output_logit0[0, -1, 352]=tensor(11.9513, device='mps:0')
logit for token |79|: output_logit1[0, -1, 3720]=tensor(9.3148, device='mps:0')
scaled_final_token_residual_stream.shape=torch.Size([1, 768])
scaled_final_token_residual_stream[0,0]=tensor(0.1353, device='mps:0')
cache0["ln_final.hook_normalized"].shape=torch.Size([1, 6, 768])
cache0["ln_final.hook_normalized"][0,-1,0]=tensor(0.1353, device='mps:0')
calculated_logit_diff.shape=torch.Size([1, 50257])
calculated_logit_diff[0,642]=tensor(11.3875, device='mps:0')
answer_logits=tensor([[11.3875, 11.9513]], device='mps:0')
Per prompt logit difference: tensor([-0.5640])
labels=['0_pre', '0_mid', '1_pre', '1_mid', '2_pre', '2_mid', '3_pre', '3_mid', '4_pre', '4_mid', '5_pre', '5_mid', '6_pre', '6_mid', '7_pre', '7_mid', '8_pre', '8_mid', '9_pre', '9_mid', '10_pre', '10_mid', '11_pre', '11_mid', 'final_post']
accumulated_residual.shape=torch.Size([25, 1, 50257])
per_layer_residual.shape=torch.Size([26, 1, 768]) labels=['embed', 'pos_embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out', '2_attn_out', '2_mlp_out', '3_attn_out', '3_mlp_out', '4_attn_out', '4_mlp_out', '5_attn_out', '5_mlp_out', '6_attn_out', '6_mlp_out', '7_attn_out', '7_mlp_out', '8_attn_out', '8_mlp_out', '9_attn_out', '9_mlp_out', '10_attn_out', '10_mlp_out', '11_attn_out', '11_mlp_out']
Tried to stack head results when they weren't cached. Computing head results now
In [140]:
print(cache0['attn', 11][0, 8].shape) # 0 = batch, 8 = head

cv.attention.attention_patterns(
    attention=cache0['attn', 8][0], tokens=model.to_str_tokens('123 + 456 =')
)
torch.Size([6, 6])
Out[140]:
In [154]:
# prompts = [
#     "123 + 456 =",
# ]
tokens = model.to_tokens('223 + 456 =', prepend_bos=True)
print(f'{tokens.shape=}')
# print(model.to_single_token('579'))
# print(model.to_str_tokens('579', prepend_bos=False))


def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache,
):
    # print(hook.name)
    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][
        :, :, head_index, :
    ]
    return corrupted_head_vector


out = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=DEVICE)
out1 = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=DEVICE)

for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector,
                          head_index=head_index, clean_cache=cache0)
        patched_logits = model.run_with_hooks(
            tokens,
            fwd_hooks=[(utils.get_act_name("q", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        out[layer, head_index] = patched_logits[0, -1, 642]
        out1[layer, head_index] = patched_logits[0, -1, 352]


px.imshow(
    utils.to_numpy(out-out1),
    color_continuous_scale="RdBu",
).show()
px.imshow(
    utils.to_numpy(out1),
    color_continuous_scale="RdBu",
).show()
# logits, cache = model.run_with_cache(tokens)
tokens.shape=torch.Size([1, 6])
In [164]:
tokens = model.to_tokens('123 + 456 =', prepend_bos=True)
logits, cache2 = model.run_with_cache(tokens, return_type='logits')
print(f'{logits[0,-1, 642]=} {logits[0,-1, 352]=}')


def amplify_fn(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
):
    if hook.name == 'blocks.0.attn.hook_v':
        print('f')
        corrupted_head_vector[:, :, 5, :]  = corrupted_head_vector[:, :, 5, :] * 0.5
        corrupted_head_vector = corrupted_head_vector * 1.1
    return corrupted_head_vector


patched_logits = model.run_with_hooks(
    tokens,
    fwd_hooks=[(utils.get_act_name("v", 0, "attn"), amplify_fn)],
    return_type="logits",
)
print(f'{patched_logits[0, -1, 642]=} {patched_logits[0, -1, 352]=}')
logits[0,-1, 642]=tensor(11.3875, device='mps:0') logits[0,-1, 352]=tensor(11.9513, device='mps:0')
f
patched_logits[0, -1, 642]=tensor(10.8292, device='mps:0') patched_logits[0, -1, 352]=tensor(11.3704, device='mps:0')

In [5]:
text = " 123456754234 + 84729123475"
tokens = model.to_tokens(text)
logits, activations = model.run_with_cache(tokens, remove_batch_dim=True)
attention_pattern = activations["pattern", 0, "attn"]
str_tokens = model.to_str_tokens(text)
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)
Out[5]:
In [6]:
loss = model(tokens, return_type='loss')
print(f'{loss=}')
loss=tensor(6.8396, device='mps:0')
In [7]:
layer_to_ablate = 0
head_index_to_ablate = 1


def head_ablation_hook(value, hook):
    print(f"Shape of the value tensor: {value.shape} {hook}")
    value[:, :, 1, :] = 0.
    value[:, :, 2, :] = 0.
    return value


activations = None
with model.hooks(fwd_hooks=[(utils.get_act_name("attn", 0, "pattern"), head_ablation_hook)]):
    print(model)
    logits, activations = model.run_with_cache(tokens, remove_batch_dim=True)
    loss = model(tokens, return_type='loss')
    print(f'{loss=}')

attention_pattern = activations["pattern", 0, "attn"]
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

# loss = model.run_with_hooks(
#     tokens,
#     return_type='loss',
#     fwd_hooks=[(
#         utils.get_act_name("v", layer_to_ablate),
#         head_ablation_hook
#     )]
# )
# cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)
HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_mid): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (ln_final): LayerNormPre(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (unembed): Unembed()
)
Shape of the value tensor: torch.Size([1, 12, 12, 12]) HookPoint()
Shape of the value tensor: torch.Size([1, 12, 12, 12]) HookPoint()
loss=tensor(7.3264, device='mps:0')
Out[7]:
In [8]:
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
# clean_prompt = "The Space Needle is in the city of"
# corrupted_prompt = "Eiffel Tower is located in the city of"

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer, incorrect_answer):
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    # print(f'{correct_index=}, {incorrect_index=}')
    # print(f'{logits.shape=}')
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

# We run on the clean prompt with the cache so we store activations to patch in later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits, correct_answer=" Seattle", incorrect_answer=" Paris")
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits, correct_answer=" Paris", incorrect_answer=" Seattle")
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")

# str_tokens = model.to_str_tokens(clean_prompt)
# print(clean_cache)
# attention_pattern = clean_cache["pattern", 0, "attn"]
# cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)
Clean logit difference: -1.701
Corrupted logit difference: 2.327
In [9]:
# We define a residual stream patching hook
# We choose to act on the residual stream at the start of the layer, so we call it resid_pre
# The type annotations are a guide to the reader and are not necessary
def residual_stream_patching_hook(resid_pre, hook, position):
    # Each HookPoint has a name attribute giving the name of the hook.
    clean_resid_pre = clean_cache[hook.name]
    resid_pre[:, position, :] = clean_resid_pre[:, position, :]
    return resid_pre


# We make a tensor to store the results for each patching run. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
print(f'{clean_tokens.shape=}')
num_positions = len(clean_tokens[0])
print(f'{num_positions=}')
ioi_patching_result = torch.zeros(
    (model.cfg.n_layers, num_positions), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(
            residual_stream_patching_hook, position=position)
        # Run the model with the patching hook
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (utils.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        # Calculate the logit difference
        patched_logit_diff = logits_to_logit_diff(patched_logits, correct_answer=" Seattle", incorrect_answer=" Paris").detach()
        # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
        ioi_patching_result[layer, position] = (
            patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)
clean_tokens.shape=torch.Size([1, 17])
num_positions=17
100%|██████████| 12/12 [00:08<00:00,  1.46it/s]
In [10]:
# Add the index to the end of the label, because plotly doesn't like duplicate labels
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
imshow(ioi_patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Normalized Logit Difference After Patching Residual Stream on the IOI Task")
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[10], line 3
      1 # Add the index to the end of the label, because plotly doesn't like duplicate labels
      2 token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
----> 3 imshow(ioi_patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Normalized Logit Difference After Patching Residual Stream on the IOI Task")

Cell In[1], line 45, in imshow(tensor, **kwargs)
     44 def imshow(tensor, **kwargs):
---> 45     px.imshow(
     46         utils.to_numpy(tensor),
     47         color_continuous_midpoint=0.0,
     48         color_continuous_scale="RdBu",
     49         **kwargs,
     50     ).show()

TypeError: imshow() got an unexpected keyword argument 'xaxis'
In [ ]:
batch_size = 10
seq_len = 50
size = (batch_size, seq_len)
input_tensor = torch.randint(0, 10000, size)

random_tokens = input_tensor.to(model.cfg.device)
repeated_tokens = einops.repeat(random_tokens, "batch seq_len -> batch (2 seq_len)")
repeated_logits = model(repeated_tokens)
correct_log_probs = model.loss_fn(repeated_logits, repeated_tokens, per_token=True)
loss_by_position = einops.reduce(correct_log_probs, "batch position -> position", "mean")
line(loss_by_position, xaxis="Position", yaxis="Loss", title="Loss by position on random repeated tokens")
In [ ]:
# We make a tensor to store the induction score for each head. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
induction_score_store = torch.zeros(
    (model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)


def induction_score_hook(pattern, hook):
    # We take the diagonal of attention paid from each destination position to source positions seq_len-1 tokens back
    # (This only has entries for tokens with index>=seq_len)
    induction_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1-seq_len)
    # Get an average score per head
    induction_score = einops.reduce(
        induction_stripe, "batch head_index position -> head_index", "mean")
    # Store the result.
    induction_score_store[hook.layer(), :] = induction_score


# We make a boolean filter on activation names, that's true only on attention pattern names.
def pattern_hook_names_filter(name): return name.endswith("pattern")


model.run_with_hooks(
    repeated_tokens,
    return_type=None,  # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(
        pattern_hook_names_filter,
        induction_score_hook
    )]
)

imshow(induction_score_store, xaxis="Head",
       yaxis="Layer", title="Induction Score by Head")
In [ ]:
induction_head_layer = 18
induction_head_index = 5
size = (1, 20)
input_tensor = torch.randint(1000, 10000, size)

single_random_sequence = input_tensor.to(model.cfg.device)
repeated_random_sequence = einops.repeat(single_random_sequence, "batch seq_len -> batch (2 seq_len)")
def visualize_pattern_hook(pattern, hook):
    display(
        cv.attention.attention_patterns(
            tokens=model.to_str_tokens(repeated_random_sequence), 
            attention=pattern[0, induction_head_index, :, :][None, :, :] # Add a dummy axis, as CircuitsVis expects 3D patterns.
        )
    )

model.run_with_hooks(
    repeated_random_sequence, 
    return_type=None, 
    fwd_hooks=[(
        utils.get_act_name("pattern", induction_head_layer), 
        visualize_pattern_hook
    )]
)
In [ ]:
test_prompt = "The quick brown fox jumped over the lazy dog"
print("Num tokens:", len(model.to_tokens(test_prompt)[0]))

def print_name_shape_hook_function(activation, hook):
    print(hook.name, activation.shape)

not_in_late_block_filter = lambda name: name.startswith("blocks.0.") or not name.startswith("blocks")

model.run_with_hooks(
    test_prompt,
    return_type=None,
    fwd_hooks=[(not_in_late_block_filter, print_name_shape_hook_function)],
)
In [ ]:
example_prompt = "1111 + 2222 ="
example_answer = " 3333"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
In [ ]:
 
In [ ]:
 
In [ ]:
 
In [ ]: